# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from absl.testing import absltest
from absl.testing import parameterized

import jax
import jax.numpy as jnp
from jax import lax
from jax.config import config
from jax.experimental import checkify
import jax._src.test_util as jtu

# Runs at import time so JAX's config flags are parsed together with absl's
# (per the method name) before the test classes below are defined.
config.parse_flags_with_absl()


class CheckifyTransformTests(jtu.JaxTestCase):
  """Tests for the errors reported by ``checkify.checkify``-wrapped functions.

  Each test calls a checkified function, which returns an ``(err, out)`` pair,
  and asserts on ``err.get()``: ``None`` when no error is expected, otherwise
  a message that starts with a known prefix.
  """

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit}", "jit": jit}
      for jit in [False, True]))
  @jtu.skip_on_devices('tpu')
  def test_jit_nan(self, jit):
    """A NaN produced by ``sin`` is reported, with and without ``jit``."""
    def f(x1, x2):
      y1 = jnp.sin(x1)
      y2 = jnp.sin(x2)
      return y1 + y2

    f = jax.jit(f) if jit else f

    err, _ = checkify.checkify(f)(3., 4.)
    self.assertIsNone(err.get())

    # sin(inf) is the NaN source here.
    err, _ = checkify.checkify(f)(3., jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit}", "jit": jit}
      for jit in [False, True]))
  def test_jit_oob(self, jit):
    """Indexing past the end of an array is reported as out-of-bounds."""
    def f(x, i):
      y = jnp.sin(x)
      z = y[i]
      w = jnp.cos(z)
      return w

    f = jax.jit(f) if jit else f

    # Index 2 is the last valid index of a length-3 array.
    err, _ = checkify.checkify(f)(jnp.arange(3), 2)
    self.assertIsNone(err.get())

    err, _ = checkify.checkify(f)(jnp.arange(3), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  @parameterized.named_parameters(
      {"testcase_name": f"_update={update_fn}", "update_fn": update_fn}
      for update_fn in ["set", "add", "multiply", "divide", "power", "min",
                        "max", "get"])
  def test_jit_oob_update(self, update_fn):
    """Each ``x.at[i].<update_fn>`` call reports an out-of-bounds index."""
    def f(x, i):
      # The method is looked up by name so one body covers every update kind.
      return getattr(x.at[i], update_fn)(1.)

    f = jax.jit(f)

    err, _ = checkify.checkify(f)(jnp.arange(3), 2)
    self.assertIsNone(err.get())

    # Index 3 is one past the end of a length-3 array.
    err, _ = checkify.checkify(f)(jnp.arange(3), 3)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit}", "jit": jit}
      for jit in [False, True]))
  def test_jit_div_errors(self, jit):
    """Division reports both divide-by-zero and NaN-producing inputs."""
    def f(x, y):
      return x/y

    f = jax.jit(f) if jit else f

    err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
    self.assertIsNone(err.get())

    # A zero in the denominator.
    err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

    # inf / inf yields NaN without any zero denominator.
    err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive div')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit}", "jit": jit}
      for jit in [False, True]))
  @jtu.skip_on_devices('tpu')
  def test_jit_multi(self, jit):
    """One function can report either an out-of-bounds or a NaN error."""
    def f(x, i):
      y = x[i]
      z = jnp.cos(y)
      return z

    f = jax.jit(f) if jit else f

    # no error
    err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2)
    self.assertIsNone(err.get())

    # oob error
    err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

    # nan error
    err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive cos')

  @parameterized.named_parameters(jtu.cases_from_list(
      {"testcase_name": f"_jit={jit}", "jit": jit}
      for jit in [False, True]))
  def test_jit_ordering(self, jit):
    """When two errors are possible, the one that occurs first is reported."""
    def f(x, i):
      y = x[i]
      z = jnp.sin(x)
      return y * z

    f = jax.jit(f) if jit else f

    # both oob and nan error, but oob happens first
    err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'out-of-bounds indexing')

  @jtu.skip_on_devices('tpu')
  def test_pmap_basic(self):
    """A NaN produced inside a ``pmap``-ed function is reported."""
    if len(jax.devices()) < 2:
      raise unittest.SkipTest("requires at least 2 devices")

    @jax.pmap
    def f(x1, x2):
      y1 = jnp.sin(x1)
      y2 = jnp.sin(x2)
      return y1 + y2

    xs = jnp.array([0., 2.])
    err, _ = checkify.checkify(f)(xs, xs)
    self.assertIsNone(err.get())

    ys = jnp.array([3., jnp.inf])
    err, _ = checkify.checkify(f)(xs, ys)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

  @jtu.skip_on_devices('tpu')
  def test_cond_basic(self):
    """Only the ``lax.cond`` branch that is taken can report an error."""
    @jax.jit
    def f(x):
      return lax.cond(x > 0,
                      lambda: jnp.sin(x),
                      lambda: x)

    err, _ = checkify.checkify(f)(3.)
    self.assertIsNone(err.get())

    # Positive infinity takes the sin branch, which produces NaN.
    err, _ = checkify.checkify(f)(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive sin')

    # Negative infinity takes the identity branch, so sin is never applied.
    err, _ = checkify.checkify(f)(-jnp.inf)
    self.assertIsNone(err.get())

  @jtu.skip_on_devices('tpu')
  def test_scan_map(self):
    """Errors in a carry-free ``lax.scan`` body are reported; outputs match."""
    def scan_body(_, x):
      return None, jnp.sin(x)

    @jax.jit
    def f(xs):
      return lax.scan(scan_body, None, xs)

    xs = jnp.array([0., 2.])
    err, (_, ch_outs) = checkify.checkify(f)(xs)
    _, outs = f(xs)
    self.assertIsNone(err.get())
    # The checkified function must return the same values as the plain one.
    self.assertArraysEqual(ch_outs, outs)

    xs = jnp.array([3., jnp.inf])
    err, (_, ch_outs) = checkify.checkify(f)(xs)
    _, outs = f(xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")
    self.assertArraysEqual(ch_outs, outs)

  @jtu.skip_on_devices('tpu')
  def test_scan_carry(self):
    """Errors depending on the ``lax.scan`` carry are reported on any step."""
    def scan_body(carry, x):
      carry = carry-1.
      # Divides by zero once the decremented carry reaches 0.
      possible_nan = jnp.sin(1./carry)
      return carry, x+possible_nan

    @jax.jit
    def f(carry, xs):
      return lax.scan(scan_body, carry, xs)

    carry, xs = 3., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on first iteration
    carry, xs = 1., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on second iteration
    carry, xs = 2., jnp.ones((4,))
    err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_error(self):
    """An error raised in the ``lax.while_loop`` body is reported."""
    def while_cond(val):
      i, _ = val
      return i < 2

    def while_body(val):
      i, x = val
      # Divides by zero when the loop counter starts at 0.
      possible_nan = jnp.sin(1./i)
      return i+1., x+possible_nan

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, (init_val, 0.))

    init_val = 1.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error(self):
    """An error raised in the ``lax.while_loop`` condition is reported."""
    def while_cond(val):
      # The result is discarded; the division exists only to trigger the check.
      _ = jnp.sin(1./val)
      return val < 2.

    def while_body(val):
      return val+1.

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, init_val)

    init_val = 1.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checkify.checkify(f)(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error_and_false(self):
    """A condition that errors and returns False still reports the error."""
    # Tests if an error is generated when cond returns False.
    def while_cond(val):
      possible_nan = jnp.sin(1./val)
      return jnp.logical_not(jnp.isnan(possible_nan))

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, lambda val: val-1, init_val)

    # error on first cond
    init_val = 0.
    err, _ = checkify.checkify(f)(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

    # error on second cond
    init_val = 1.
    err, _ = checkify.checkify(f)(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "divided by zero")

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_and_cond_error(self):
    """With errors in both cond and body, the cond error is reported first."""
    def while_cond(val):
      i, cond_val, _ = val
      # Result discarded; sin(inf) is the NaN source for the cond error.
      _ = jnp.sin(cond_val)
      return i < 2

    def while_body(val):
      i, cond_val, body_val = val
      possible_nan = jnp.cos(body_val)
      return i+1., cond_val, possible_nan

    @jax.jit
    def f(cond_val, body_val):
      return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))

    cond_val = jnp.inf
    body_val = 1.
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive sin")

    cond_val = 1.
    body_val = jnp.inf
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive cos")

    cond_val = jnp.inf
    body_val = jnp.inf
    err, _ = checkify.checkify(f)(cond_val, body_val)
    self.assertIsNotNone(err.get())
    # first error which occurs is in cond
    self.assertStartsWith(err.get(), "nan generated by primitive sin")

class AssertPrimitiveTests(jtu.JaxTestCase):
  """Tests for ``checkify.assert_`` / ``checkify.assert2_`` used directly."""

  def test_assert_primitive_impl(self):
    """Outside any transform, a failing ``assert_`` raises AssertionError."""
    def always_fails():
      checkify.assert_(False, "hi")

    with self.assertRaisesRegex(AssertionError, "hi"):
      always_fails()

  def test_assert_primitive_(self):
    """Under ``jit`` without checkify, ``assert_`` refuses to be staged."""
    def always_fails():
      checkify.assert_(False, "hi")

    jitted = jax.jit(always_fails)

    with self.assertRaisesRegex(Exception, "can't be staged"):
      jitted()

  def test_assert_discharging(self):
    """``checkify`` turns a failing ``assert_`` into a returned error value."""
    def log_positive(x):
      checkify.assert_(x > 0, "must be positive!")
      return jnp.log(x)

    checked = checkify.checkify(log_positive)

    def verify(fn):
      # A positive input passes the assertion.
      err, _ = fn(1.)
      self.assertIsNone(err.get())

      # Zero violates ``x > 0`` and must surface the assertion message.
      err, _ = fn(0.)
      self.assertIsNotNone(err.get())
      self.assertStartsWith(err.get(), "must be positive")

    # Same expectations for the plain checkified function and its jitted form.
    verify(checked)
    verify(jax.jit(checked))

  def test_assert2(self):
    """``assert2_`` raises eagerly and is discharged by ``checkify``."""
    def fails_on(pred):  # the predicate must be an argument (data dependence)
      checkify.assert2_(pred, 0, {0: "hi"})

    with self.assertRaisesRegex(AssertionError, "hi"):
      fails_on(False)

    checked = checkify.checkify(fails_on)
    err, out = checked(False)

    self.assertIsNone(out)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "hi")

  def test_discharge_recharge(self):
    """An error discharged inside ``jit`` can be re-raised outside of it."""
    def ejit(fn):
      # Discharge assertions first, then jit the now error-returning function.
      compiled = jax.jit(checkify.checkify(fn))

      def run(*args):
        err, out = compiled(*args)
        # Re-raise ("recharge") the discharged error outside the jitted call.
        checkify.assert2_(~err.err, err.code, err.msgs)
        return out

      return run

    @ejit
    def guarded(pred):
      # Trips if this Python body is executed while the flag below is False.
      assert tracing_allowed
      checkify.assert_(pred, "foo")

    tracing_allowed = True
    guarded(True)

    # From here on the Python body must not run again, or its assert would fail.
    tracing_allowed = False
    guarded(True)
    with self.assertRaisesRegex(AssertionError, "foo"):
      guarded(False)


if __name__ == "__main__":
  # When executed as a script, run the tests via absltest with JAX's loader.
  absltest.main(testLoader=jtu.JaxTestLoader())
